import torch

from torch_geometric.datasets import JODIEDataset
from .label_prediction_sequence import LabelSequencePrediction
from .pascalvoc import TemporalPascalVOC
from numpy.random import default_rng
import numpy


# Name of the temporal PascalVOC dataset (loaded via TemporalPascalVOC).
PASCALVOC = "Temporal_PascalVOC-SP_1024"

# Datasets loaded through torch_geometric's JODIEDataset.
JODIE = ["Wikipedia", "Reddit", "MOOC", "LastFM"]

# Synthetic label-sequence tasks, named "LabelSequence_<seq_len>_<num_seq>";
# `synthetic()` parses the two trailing integers back out of the name.
SYNTHETIC_SEQUENCE = [
    f"LabelSequence_{seq_len}_1000" for seq_len in (3, 5, 7, 9, 11, 15, 20)
]

# Every dataset name accepted by `get_dataset`, in a stable order.
DATA_NAMES = [*JODIE, *SYNTHETIC_SEQUENCE, PASCALVOC]

def get_dataset(root, name, seed):
    """Load a dataset by name and report the dimensions a model needs.

    Args:
        root: Directory handed to the underlying dataset class.
        name: One of ``DATA_NAMES``.
        seed: Seed for ``numpy.random.default_rng``. It is used to draw the random
            node features of the JODIE datasets and is forwarded to ``synthetic``.

    Returns:
        Tuple ``(data, num_nodes, edge_dim, node_dim, out_dim, init_time)``.
        ``out_dim`` is a Python ``int``. ``init_time`` is the first entry of the
        timestamp tensor ``t`` (of the training split for PascalVOC).

    Raises:
        NotImplementedError: If ``name`` is not a known dataset.
        ValueError: If the PascalVOC labels do not cover exactly 21 classes.
    """
    rng = default_rng(seed)
    if name in JODIE:
        dataset = JODIEDataset(root, name)
        data = dataset[0]
        # Node features are overwritten with one seeded, uniform-random
        # float32 feature per node.
        data.x = torch.tensor(rng.random((data.num_nodes, 1), dtype=numpy.float32))
        num_nodes, edge_dim = data.num_nodes, data.msg.shape[-1]
        node_dim = data.x.shape[-1]
        init_time = data.t[0]
        out_dim = 1

    elif name in SYNTHETIC_SEQUENCE:
        (data, num_nodes, edge_dim, node_dim,
            out_dim, init_time) = synthetic(root, name, seed)

    elif name == PASCALVOC:
        data = TemporalPascalVOC(root=root, name=name)
        num_nodes = data.num_nodes
        edge_dim = data.train_data.msg.shape[-1]
        node_dim = data.train_data.x.shape[-1]
        init_time = data.train_data.t[0]
        # Largest label over all splits, + 1. int() converts the 0-d tensors so
        # that out_dim is a plain int, as in the other branches.
        max_label = max(
            int(data.train_data.y.max()),
            int(data.val_data.y.max()),
            int(data.test_data.y.max()),
        )
        out_dim = max_label + 1
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if out_dim != 21:
            raise ValueError(
                f"Expected exactly 21 classes for {name}, found {out_dim}"
            )
    else:
        raise NotImplementedError(
            f"Unknown dataset {name!r}; expected one of {DATA_NAMES}"
        )

    return data, num_nodes, edge_dim, node_dim, out_dim, init_time


def synthetic(root, name, seed):
    """Build a synthetic label-sequence dataset described by ``name``.

    ``name`` is expected to end in ``_<seq_len>_<num_seq>`` (for example
    ``"LabelSequence_3_1000"``). The two integers are parsed and passed to
    ``LabelSequencePrediction`` together with ``root`` and ``seed``.

    Returns:
        Tuple ``(data, num_nodes, edge_dim, node_dim, out_dim, init_time)``, taken
        from the generated dataset and its first element ``data.data[0]``.

    Raises:
        ValueError: If ``name`` is not a LabelSequence name, or if its numeric
            suffix is missing or not an integer.
    """
    error_msg = f'The name is not in {SYNTHETIC_SEQUENCE}. Got {name}'

    # Validate before parsing. Otherwise int() fails first on names such as
    # "Wikipedia" with an opaque message, and this error is never reached.
    if 'LabelSequence' not in name:
        raise ValueError(error_msg)

    parts = name.split('_')
    try:
        seq_len = int(parts[-2])
        num_seq = int(parts[-1])
    except (IndexError, ValueError) as err:
        raise ValueError(error_msg) from err

    data = LabelSequencePrediction(root=root, name=name, seq_len=seq_len,
                                   num_seq=num_seq, seed=seed)
    first = data.data[0]
    num_nodes = data.num_nodes
    edge_dim, node_dim = first.msg.shape[-1], first.x.shape[-1]
    init_time = first.t[0]
    out_dim = first.y.shape[-1]
    assert out_dim == 1, f"Expected 1-dimensional targets, got {out_dim}"

    return data, num_nodes, edge_dim, node_dim, out_dim, init_time